# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for decoder factory functions."""

from absl.testing import parameterized
import tensorflow as tf, tf_keras

from tensorflow.python.distribute import combinations
from official.vision import configs
from official.vision.configs import decoders as decoders_cfg
from official.vision.modeling import decoders
from official.vision.modeling.decoders import factory


class FactoryTest(tf.test.TestCase, parameterized.TestCase):
  """Checks `factory.build_decoder` against directly constructed decoders.

  The FPN, NASFPN and ASPP cases each build a decoder twice -- once by calling
  the decoder class and once through the factory from a model config -- and
  compare the two `get_config()` results. The identity case only checks that
  the factory returns `None`.
  """

  def _feature_pyramid_specs(self, min_level, max_level):
    """Returns input specs for levels `min_level` up to `max_level - 1`.

    Args:
      min_level: First level to include.
      max_level: Exclusive upper bound on the included levels.

    Returns:
      A dict mapping `str(level)` to
      `tf.TensorShape([1, 128 // 2**level, 128 // 2**level, 3])`.
    """
    sides = {level: 128 // (2**level) for level in range(min_level, max_level)}
    return {
        str(level): tf.TensorShape([1, side, side, 3])
        for level, side in sides.items()
    }

  def _retinanet_config(self, min_level, max_level, decoder):
    """Returns a RetinaNet config with the given levels and decoder config.

    Args:
      min_level: Value assigned to the config's `min_level`.
      max_level: Value assigned to the config's `max_level`.
      decoder: A `decoders_cfg.Decoder` assigned to the config's `decoder`.

    Returns:
      A `configs.retinanet.RetinaNet` with 10 classes and an input size of
      `[None, None, 3]`.
    """
    retinanet_cfg = configs.retinanet.RetinaNet()
    retinanet_cfg.min_level = min_level
    retinanet_cfg.max_level = max_level
    retinanet_cfg.num_classes = 10
    retinanet_cfg.input_size = [None, None, 3]
    retinanet_cfg.decoder = decoder
    return retinanet_cfg

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256], use_separable_conv=[True, False]))
  def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
    """Test creation of FPN decoder.

    Builds an FPN directly and through `factory.build_decoder`, then asserts
    that both report the same `get_config()`.
    """
    min_level, max_level = 3, 7
    input_specs = self._feature_pyramid_specs(min_level, max_level)

    direct_decoder = decoders.FPN(
        input_specs=input_specs,
        num_filters=num_filters,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    fpn_cfg = decoders_cfg.FPN(
        num_filters=num_filters, use_separable_conv=use_separable_conv)
    model_config = self._retinanet_config(
        min_level, max_level, decoders_cfg.Decoder(type='fpn', fpn=fpn_cfg))
    built_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    self.assertEqual(direct_decoder.get_config(), built_decoder.get_config())

  @combinations.generate(
      combinations.combine(
          num_filters=[128, 256],
          num_repeats=[3, 5],
          use_separable_conv=[True, False]))
  def test_nasfpn_decoder_creation(self, num_filters, num_repeats,
                                   use_separable_conv):
    """Test creation of NASFPN decoder.

    Builds a NASFPN directly and through `factory.build_decoder`, then asserts
    that both report the same `get_config()`.
    """
    min_level, max_level = 3, 7
    input_specs = self._feature_pyramid_specs(min_level, max_level)

    direct_decoder = decoders.NASFPN(
        input_specs=input_specs,
        num_filters=num_filters,
        num_repeats=num_repeats,
        use_separable_conv=use_separable_conv,
        use_sync_bn=True)

    nasfpn_cfg = decoders_cfg.NASFPN(
        num_filters=num_filters,
        num_repeats=num_repeats,
        use_separable_conv=use_separable_conv)
    model_config = self._retinanet_config(
        min_level, max_level,
        decoders_cfg.Decoder(type='nasfpn', nasfpn=nasfpn_cfg))
    built_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    self.assertEqual(direct_decoder.get_config(), built_decoder.get_config())

  @combinations.generate(
      combinations.combine(
          level=[3, 4],
          dilation_rates=[[6, 12, 18], [6, 12]],
          num_filters=[128, 256]))
  def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
    """Test creation of ASPP decoder.

    Builds an ASPP directly and through `factory.build_decoder`, then asserts
    that both report the same `get_config()` once the layer names are aligned.
    """
    input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

    direct_decoder = decoders.ASPP(
        level=level,
        dilation_rates=dilation_rates,
        num_filters=num_filters,
        use_sync_bn=True)

    aspp_cfg = decoders_cfg.ASPP(
        level=level, dilation_rates=dilation_rates, num_filters=num_filters)
    model_config = configs.semantic_segmentation.SemanticSegmentationModel()
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(type='aspp', aspp=aspp_cfg)

    built_decoder = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    direct_config = direct_decoder.get_config()
    built_config = built_decoder.get_config()
    # The ASPP layer calls `super().get_config()`, so the two configs match in
    # everything except the names of the two layer instances. Copy the name
    # across so that difference does not raise a false alarm.
    built_config['name'] = direct_config['name']

    self.assertEqual(direct_config, built_config)

  def test_identity_decoder_creation(self):
    """Test creation of identity decoder.

    Asserts that `factory.build_decoder` returns `None` for the 'identity'
    decoder type.
    """
    identity_decoder_cfg = decoders_cfg.Decoder(
        type='identity', identity=decoders_cfg.Identity())

    model_config = configs.retinanet.RetinaNet()
    model_config.num_classes = 2
    model_config.input_size = [None, None, 3]
    model_config.decoder = identity_decoder_cfg

    self.assertIsNone(
        factory.build_decoder(input_specs=None, model_config=model_config))


if __name__ == '__main__':
  # When executed as a script, hand control to `tf.test.main()`.
  tf.test.main()
